{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "622d49d1-5358-4321-bfc7-b976629c7754",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install --upgrade --quiet neo4j langchain-experimental langchain_community langchain langchain-google-vertexai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4a56e297-34a6-406d-95be-b78312c4d27c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"GOOGLE_APPLICATION_CREDENTIALS\"] = \"credentials.json\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a44ef26b-6296-4d2e-9614-d095b7f99fa3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, List, Optional, Sequence\n",
    "\n",
    "from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship\n",
    "from langchain_core.documents import Document\n",
    "from langchain_core.language_models import BaseLanguageModel\n",
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from langchain_core.pydantic_v1 import BaseModel, Field\n",
    "\n",
    "system_prompt = (\n",
    "    \"# Knowledge Graph Instructions for GPT-4\\n\"\n",
    "    \"## 1. Overview\\n\"\n",
    "    \"You are a top-tier algorithm designed for extracting information in structured \"\n",
    "    \"formats to build a knowledge graph.\\n\"\n",
    "    \"Try to capture as much information from the text as possible without \"\n",
    "    \"sacrifing accuracy. Do not add any information that is not explicitly \"\n",
    "    \"mentioned in the text\\n\"\n",
    "    \"- **Nodes** represent entities and concepts.\\n\"\n",
    "    \"- The aim is to achieve simplicity and clarity in the knowledge graph, making it\\n\"\n",
    "    \"accessible for a vast audience.\\n\"\n",
    "    \"## 2. Labeling Nodes\\n\"\n",
    "    \"- **Consistency**: Ensure you use available types for node labels.\\n\"\n",
    "    \"Ensure you use basic or elementary types for node labels.\\n\"\n",
    "    \"- For example, when you identify an entity representing a person, \"\n",
    "    \"always label it as **'person'**. Avoid using more specific terms \"\n",
    "    \"like 'mathematician' or 'scientist'\"\n",
    "    \"  - **Node IDs**: Never utilize integers as node IDs. Node IDs should be \"\n",
    "    \"names or human-readable identifiers found in the text.\\n\"\n",
    "    \"- **Relationships** represent connections between entities or concepts.\\n\"\n",
    "    \"Ensure consistency and generality in relationship types when constructing \"\n",
    "    \"knowledge graphs. Instead of using specific and momentary types \"\n",
    "    \"such as 'BECAME_PROFESSOR', use more general and timeless relationship types \"\n",
    "    \"like 'PROFESSOR'. Make sure to use general and timeless relationship types!\\n\"\n",
    "    \"## 3. Coreference Resolution\\n\"\n",
    "    \"- **Maintain Entity Consistency**: When extracting entities, it's vital to \"\n",
    "    \"ensure consistency.\\n\"\n",
    "    'If an entity, such as \"John Doe\", is mentioned multiple times in the text '\n",
    "    'but is referred to by different names or pronouns (e.g., \"Joe\", \"he\"),'\n",
    "    \"always use the most complete identifier for that entity throughout the \"\n",
    "    'knowledge graph. In this example, use \"John Doe\" as the entity ID.\\n'\n",
    "    \"Remember, the knowledge graph should be coherent and easily understandable, \"\n",
    "    \"so maintaining consistency in entity references is crucial.\\n\"\n",
    "    \"## 4. Strict Compliance\\n\"\n",
    "    \"Adhere to the rules strictly. Non-compliance will result in termination.\"\n",
    ")\n",
    "\n",
    "default_prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\n",
    "            \"system\",\n",
    "            system_prompt,\n",
    "        ),\n",
    "        (\n",
    "            \"human\",\n",
    "            (\n",
    "                \"Tip: Make sure to answer in the correct format and do \"\n",
    "                \"not include any explanations. \"\n",
    "                \"Use the given format to extract information from the \"\n",
    "                \"following input: {input}\"\n",
    "            ),\n",
    "        ),\n",
    "    ]\n",
    ")\n",
    "\n",
    "\n",
    "def optional_enum_field(\n",
    "    enum_values: Optional[List[str]] = None,\n",
    "    description: str = \"\",\n",
    "    is_rel: bool = False,\n",
    "    **field_kwargs: Any,\n",
    ") -> Any:\n",
    "    \"\"\"Utility function to conditionally create a field with an enum constraint.\"\"\"\n",
    "    if enum_values:\n",
    "        return Field(\n",
    "            ...,\n",
    "            enum=enum_values,\n",
    "            description=f\"{description}. Available options are {enum_values}\",\n",
    "            **field_kwargs,\n",
    "        )\n",
    "    else:\n",
    "        node_info = (\n",
    "            \"Ensure you use basic or elementary types for node labels.\\n\"\n",
    "            \"For example, when you identify an entity representing a person, \"\n",
    "            \"always label it as **'Person'**. Avoid using more specific terms \"\n",
    "            \"like 'Mathematician' or 'Scientist'\"\n",
    "        )\n",
    "        rel_info = (\n",
    "            \"Instead of using specific and momentary types such as \"\n",
    "            \"'BECAME_PROFESSOR', use more general and timeless relationship types like \"\n",
    "            \"'PROFESSOR'. However, do not sacrifice any accuracy for generality\"\n",
    "        )\n",
    "        additional_info = rel_info if is_rel else node_info\n",
    "        return Field(..., description=description + additional_info, **field_kwargs)\n",
    "\n",
    "\n",
    "def create_simple_model(\n",
    "    node_labels: Optional[List[str]] = None, rel_types: Optional[List[str]] = None\n",
    ") -> Any:\n",
    "    \"\"\"\n",
    "    Simple model allows to limit node and/or relationship types.\n",
    "    Doesn't have any node or relationship properties.\n",
    "    \"\"\"\n",
    "    class SimpleRelationship(BaseModel):\n",
    "        \"\"\"Represents a directed relationship between two nodes in a graph.\"\"\"\n",
    "        start_node_id: str = optional_enum_field(\n",
    "            node_labels, description=\"The type or label of the start node.\"\n",
    "        )\n",
    "        start_node_type: str = Field(description=\"Mandatory type of the start node of the relationship.\")\n",
    "        end_node_id: str = Field(description=\"Name or human-readable unique identifier of end node\")\n",
    "        end_node_type: str = optional_enum_field(\n",
    "            node_labels, description=\"The type or label of the end node.\"\n",
    "        )\n",
    "        type: str = optional_enum_field(\n",
    "            rel_types, description=\"The type of the relationship.\", is_rel=True\n",
    "        )\n",
    "\n",
    "    class DynamicGraph(BaseModel):\n",
    "        \"\"\"Represents a graph document consisting of nodes and relationships.\"\"\"\n",
    "\n",
    "        relationships: Optional[List[SimpleRelationship]] = Field(\n",
    "            description=\"List of relationships\"\n",
    "        )\n",
    "\n",
    "    return DynamicGraph\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7e4c1cf7-357e-4db5-8616-fb6aabeec87a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/tomazbratanic/anaconda3/lib/python3.11/site-packages/pandas/core/arrays/masked.py:60: UserWarning: Pandas requires version '1.3.6' or newer of 'bottleneck' (version '1.3.5' currently installed).\n",
      "  from pandas.core import (\n"
     ]
    }
   ],
   "source": [
    "from langchain.pydantic_v1 import BaseModel\n",
    "from langchain_google_vertexai import ChatVertexAI\n",
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from typing import List\n",
    "from langchain_core.pydantic_v1 import BaseModel, Field\n",
    "\n",
    "schema = create_simple_model()\n",
    "\n",
    "llm = ChatVertexAI(model_name=\"gemini-pro\", temperature=0, convert_system_message_to_human=True)\n",
    "llm_transformer = llm.with_structured_output(schema)\n",
    "\n",
    "chain = default_prompt | llm_transformer\n",
    "\n",
    "text =\"\"\"\n",
    "Maria Salomea Skłodowska-Curie[a] (Polish: [ˈmarja salɔˈmɛa skwɔˈdɔfska kʲiˈri] ⓘ; née Skłodowska; 7 November 1867 – 4 July 1934), known simply as Marie Curie (/ˈkjʊəri/ KURE-ee,[1] French: [maʁi kyʁi]), was a Polish and naturalised-French physicist and chemist who conducted pioneering research on radioactivity. She was the first woman to win a Nobel Prize, the first person to win a Nobel Prize twice, and the only person to win a Nobel Prize in two scientific fields. Her husband, Pierre Curie, was a co-winner of her first Nobel Prize, making them the first-ever married couple to win the Nobel Prize and launching the Curie family legacy of five Nobel Prizes. She was, in 1906, the first woman to become a professor at the University of Paris.[2]\n",
    "\"\"\"\n",
    "\n",
    "llm_output = chain.invoke({\"input\": text})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1f9e1266-f0e8-479b-add6-9c595be32d2a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DynamicGraph(relationships=[SimpleRelationship(start_node_id='Maria Salomea Skłodowska-Curie', start_node_type='Person', end_node_id='Nobel Prize', end_node_type='Award', type='WINNER'), SimpleRelationship(start_node_id='Maria Salomea Skłodowska-Curie', start_node_type='Person', end_node_id='University of Paris', end_node_type='Organization', type='PROFESSOR'), SimpleRelationship(start_node_id='Maria Salomea Skłodowska-Curie', start_node_type='Person', end_node_id='Pierre Curie', end_node_type='Person', type='SPOUSE'), SimpleRelationship(start_node_id='Pierre Curie', start_node_type='Person', end_node_id='Nobel Prize', end_node_type='Award', type='WINNER')])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "llm_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d43831a0-a671-4bc9-ba22-f017b4969862",
   "metadata": {},
   "outputs": [],
   "source": [
    "def map_to_base_relationship(rel: Any) -> Relationship:\n",
    "    \"\"\"Map the SimpleRelationship to the base Relationship.\"\"\"\n",
    "    source = Node(id=rel.start_node_id.title(), type=rel.start_node_type.capitalize())\n",
    "    target = Node(id=rel.end_node_id.title(), type=rel.end_node_type.capitalize())\n",
    "    return Relationship(\n",
    "        source=source, target=target, type=rel.type.replace(\" \", \"_\").upper()\n",
    "    )\n",
    "\n",
    "relationships = (\n",
    "    [map_to_base_relationship(rel) for rel in llm_output.relationships]\n",
    "    if llm_output.relationships\n",
    "    else []\n",
    ")\n",
    "\n",
    "graph_documents = [GraphDocument(nodes=[], relationships=relationships, source=Document(page_content=text))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b67e8d77-9f45-4ce7-b897-b663f1d655f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from langchain_community.graphs import Neo4jGraph\n",
    "\n",
    "os.environ[\"NEO4J_URI\"] = \"bolt://localhost:7687\"\n",
    "os.environ[\"NEO4J_USERNAME\"] = \"neo4j\"\n",
    "os.environ[\"NEO4J_PASSWORD\"] = \"password\"\n",
    "\n",
    "graph = Neo4jGraph()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3de6d65c-94fb-4563-89e1-368faae75292",
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.add_graph_documents(graph_documents)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "133dc477-54c7-4b25-aba2-0f405e5fb8d0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
